{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第七章 K近邻算法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1.K近邻算法简单代码演示"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>原始样本</th>\n",
       "      <th>酒精含量(%)</th>\n",
       "      <th>苹果酸含量(%)</th>\n",
       "      <th>分类</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>样本1</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>样本2</td>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>样本3</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>样本4</td>\n",
       "      <td>8</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>样本5</td>\n",
       "      <td>10</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  原始样本  酒精含量(%)  苹果酸含量(%)  分类\n",
       "0  样本1        5         2   0\n",
       "1  样本2        6         1   0\n",
       "2  样本3        4         1   0\n",
       "3  样本4        8         3   1\n",
       "4  样本5       10         2   1"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_excel('葡萄酒.xlsx')\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 特征变量和目标变量的切分\n",
    "X_train = df[['酒精含量(%)','苹果酸含量(%)']]\n",
    "y_train = df['分类']  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=None, n_neighbors=3, p=2,\n",
       "           weights='uniform')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 模型训练\n",
    "from sklearn.neighbors import KNeighborsClassifier as KNN\n",
    "knn = KNN(n_neighbors=3)  \n",
    "knn.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0]\n"
     ]
    }
   ],
   "source": [
    "# 模型预测：预测单个样本\n",
    "X_test = [[7, 1]]  # X_test为测试集特征变量\n",
    "answer = knn.predict(X_test)  \n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 1]\n"
     ]
    }
   ],
   "source": [
    "# 模型预测：预测多个样本\n",
    "X_test = [[7, 1], [8, 3]]  # 这里能帮助理解为什么要写成二维数组的样式\n",
    "answer = knn.predict(X_test)  \n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "补充知识点：K近邻算法回归模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2.5]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 2, 3, 4, 5]\n",
    "\n",
    "model = KNeighborsRegressor(n_neighbors=2)\n",
    "model.fit(X, y)\n",
    "\n",
    "print(model.predict([[5, 5]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2.数据归一化代码演示"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2.1 min-max标准化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_excel('葡萄酒2.xlsx')\n",
    "X = df[['酒精含量(%)','苹果酸含量(%)']]\n",
    "y = df['分类']  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\sklearn\\preprocessing\\data.py:334: DataConversionWarning: Data with input dtype int64 were all converted to float64 by MinMaxScaler.\n",
      "  return self.partial_fit(X, y)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[0.16666667, 0.5       ],\n",
       "       [0.33333333, 0.        ],\n",
       "       [0.        , 0.        ],\n",
       "       [0.66666667, 1.        ],\n",
       "       [1.        , 0.5       ]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.preprocessing import MinMaxScaler\n",
    "X_new = MinMaxScaler().fit_transform(X)\n",
    "X_new"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2.2 Z-score标准化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\sklearn\\preprocessing\\data.py:645: DataConversionWarning: Data with input dtype int64 were all converted to float64 by StandardScaler.\n",
      "  return self.partial_fit(X, y)\n",
      "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\sklearn\\base.py:464: DataConversionWarning: Data with input dtype int64 were all converted to float64 by StandardScaler.\n",
      "  return self.fit(X, **fit_params).transform(X)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[-0.74278135,  0.26726124],\n",
       "       [-0.27854301, -1.06904497],\n",
       "       [-1.2070197 , -1.06904497],\n",
       "       [ 0.64993368,  1.60356745],\n",
       "       [ 1.57841037,  0.26726124]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "X_new = StandardScaler().fit_transform(X)\n",
    "X_new"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3.案例实战 - 手写数字识别模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>对应数字</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>...</th>\n",
       "      <th>1014</th>\n",
       "      <th>1015</th>\n",
       "      <th>1016</th>\n",
       "      <th>1017</th>\n",
       "      <th>1018</th>\n",
       "      <th>1019</th>\n",
       "      <th>1020</th>\n",
       "      <th>1021</th>\n",
       "      <th>1022</th>\n",
       "      <th>1023</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 1025 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   对应数字  0  1  2  3  4  5  6  7  8  ...  1014  1015  1016  1017  1018  1019  \\\n",
       "0     0  0  0  0  0  0  0  0  0  0  ...     0     0     0     0     0     0   \n",
       "1     0  0  0  0  0  0  0  0  0  0  ...     0     0     0     0     0     0   \n",
       "2     0  0  0  0  0  0  0  0  0  0  ...     0     0     0     0     0     0   \n",
       "3     0  0  0  0  0  0  0  0  0  0  ...     0     0     0     0     0     0   \n",
       "4     0  0  0  0  0  0  0  0  0  0  ...     0     0     0     0     0     0   \n",
       "\n",
       "   1020  1021  1022  1023  \n",
       "0     0     0     0     0  \n",
       "1     0     0     0     0  \n",
       "2     0     0     0     0  \n",
       "3     0     0     0     0  \n",
       "4     0     0     0     0  \n",
       "\n",
       "[5 rows x 1025 columns]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 1.读取数据\n",
    "import pandas as pd\n",
    "df = pd.read_excel('手写字体识别.xlsx')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2.提取特征变量和目标变\n",
    "X = df.drop(columns='对应数字') \n",
    "y = df['对应数字']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3.划分训练集和测试集\n",
    "from sklearn.model_selection import train_test_split\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=None, n_neighbors=5, p=2,\n",
       "           weights='uniform')"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 4.模型搭建\n",
    "from sklearn.neighbors import KNeighborsClassifier as KNN\n",
    "knn = KNN(n_neighbors=5) \n",
    "knn.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5 3 7 8 9 2 1 4 5 8 9 5 9 3 3 2 3 7 9 1 0 0 7 6 6 7 0 9 6 9 1 8 6 9 2 5 2\n",
      " 4 5 8 3 6 9 4 9 2 7 3 4 9 5 6 7 3 3 8 3 1 5 3 6 7 5 0 3 7 1 4 9 1 5 1 2 6\n",
      " 9 1 9 5 5 9 2 8 8 4 4 9 4 3 9 8 0 3 4 3 6 8 5 2 9 0]\n"
     ]
    }
   ],
   "source": [
    "# 5.模型预测 - 预测数据结果\n",
    "y_pred = knn.predict(X_test)\n",
    "print(y_pred[0:100])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>预测值</th>\n",
       "      <th>实际值</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>8</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>9</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   预测值  实际值\n",
       "0    5    5\n",
       "1    3    3\n",
       "2    7    7\n",
       "3    8    8\n",
       "4    9    9"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = pd.DataFrame()  # 创建一个空DataFrame \n",
    "a['预测值'] = list(y_pred)\n",
    "a['实际值'] = list(y_test)\n",
    "a.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.979328165374677"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 预测准确度评估\n",
    "from sklearn.metrics import accuracy_score\n",
    "score = accuracy_score(y_pred, y_test)\n",
    "score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.979328165374677"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 模型自带的score()函数也可以进行打分\n",
    "score = knn.score(X_test, y_test)\n",
    "score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.补充知识点：图像识别原理详解"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4.1 图片大小调整及显示"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAA1ElEQVR4nGP8//8/Ay0BE01NH7EWvDYy+pSXT6RiRpIi+bWREZqI6Llz+LWQ4ANM04kBlMYBQVuJtYA855NgAdmAKAuQnU8wVkm24L2uHkkmkmzBH1YWhOpfv6hsAVrcCl+5QmULkAEbWcUuPgvQnM9/nrToJWwBVQBOC9CcT2rqJGwBtQB2C6jlfAasxfVHe/tfnz+TahC7qCjfzp2Y4lh8QIbpDAwMP1+/xio+EHHATHp5wIA7nkirMiGApCQwOOqDUQtIBcixSjCTk5OKSAI0DyIA/6VDQJqfHnAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=32x32 at 0x202B424C470>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from PIL import Image\n",
    "img = Image.open('数字4.png')\n",
    "img = img.resize((32,32))\n",
    "img.show()\n",
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4.2 图片灰度处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAAAAABWESUoAAAAhElEQVR4nMWTMQ6AIAxFH9zAU5runsEbKIOHdHLT6IQJVIoMxm7k//y+FnAndvmK/olBpuToMkgBIBQT5CWDlAw6oHEKIeFThr6W4IHdMAjAYjKAM1oIwGwyPNZtENAztiSUAuJ1D1sudGOSoHTWRoZDCZEnf3IK95d/UTUEkpXqKVpbXL7yFn/67SxYAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=32x32 at 0x202AF11C048>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img = img.convert('L')\n",
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4.3 图片二值化处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n"
     ]
    }
   ],
   "source": [
    "# 二值化处理\n",
    "import numpy as np\n",
    "img_new = img.point(lambda x: 0 if x > 128 else 1)\n",
    "arr = np.array(img_new)\n",
    "\n",
    "# 打印arr中的每一行\n",
    "for i in range(arr.shape[0]):\n",
    "    print(arr[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4.4 二维数组转一维数组"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 0, 0, ..., 0, 0, 0]], dtype=uint8)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "arr_new = arr.reshape(1, -1)\n",
    "arr_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 1024)\n"
     ]
    }
   ],
   "source": [
    "print(arr_new.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此时我们可以把这个处理过的图片“数字4”传入到我们上面训练好的knn模型中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "图片中的数字为：4\n"
     ]
    }
   ],
   "source": [
    "answer = knn.predict(arr_new) \n",
    "print('图片中的数字为：' + str(answer[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.代码汇总与测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "图片中的数字为：1\n"
     ]
    }
   ],
   "source": [
    "# 主要分为三步，第一步训练模型，第二步处理图片，第三步导入模型并预测\n",
    "# 1.训练模型\n",
    "# 1.1 读取数据\n",
    "import pandas as pd\n",
    "df = pd.read_excel('手写字体识别.xlsx')\n",
    "\n",
    "# 1.2 提取特征变量和目标变量\n",
    "X = df.drop(columns='对应数字') \n",
    "y = df['对应数字']   \n",
    "\n",
    "# 1.3 划分训练集和测试集\n",
    "from sklearn.model_selection import train_test_split\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)\n",
    "\n",
    "# 1.4 训练模型\n",
    "from sklearn.neighbors import KNeighborsClassifier as KNN\n",
    "knn = KNN(n_neighbors=5) \n",
    "knn.fit(X_train, y_train)  # 自此手写字体识别模型便搭建好了\n",
    "\n",
    "# 2.处理图片\n",
    "# 2.1 图片读取 & 大小调整 & 灰度处理\n",
    "from PIL import Image\n",
    "img = Image.open('测试图片.png')  # 这里传入手写的图片，注意写对文件路径\n",
    "img = img.resize((32,32))\n",
    "img = img.convert('L')\n",
    "\n",
    "# 2.2 图片二值化处理 & 二维数据转一维数据\n",
    "import numpy as np\n",
    "img_new = img.point(lambda x: 0 if x > 128 else 1)\n",
    "arr = np.array(img_new)\n",
    "arr_new = arr.reshape(1, -1)\n",
    "\n",
    "# 3.预测手写数字\n",
    "answer = knn.predict(arr_new) \n",
    "print('图片中的数字为：' + str(answer[0]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
